import torch
from tqdm import tqdm
import random
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import sys

def normalize(images):
    """Apply per-channel mean/std normalisation to an image batch.

    The constants appear to be the OpenAI CLIP preprocessing statistics.

    Args:
        images: image tensor with the 3 colour channels on the axis that
            lines up with dim 1 of a (1, 3, 1, 1) tensor — presumably
            (B, 3, H, W) with values in [0, 1]; TODO confirm with callers.

    Returns:
        ``(images - mean) / std`` computed channel-wise.
    """
    # Build the constants on the input's own device. The previous hard-coded
    # ``.cuda()`` broke CPU inputs and inputs living on a non-current GPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images - mean[None, :, None, None]
    images = images / std[None, :, None, None]
    return images

def denormalize(images):
    """Undo per-channel mean/std normalisation: ``images * std + mean``.

    The constants appear to be the OpenAI CLIP preprocessing statistics.

    Args:
        images: normalised image tensor with the 3 colour channels on the axis
            that lines up with dim 1 of a (1, 3, 1, 1) tensor — presumably
            (B, 3, H, W); TODO confirm with callers.

    Returns:
        The de-normalised tensor, computed channel-wise.
    """
    # Build the constants on the input's own device. The previous hard-coded
    # ``.cuda()`` broke CPU inputs and inputs living on a non-current GPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images * std[None, :, None, None]
    images = images + mean[None, :, None, None]
    return images


class Attacker:
    """PGD-style optimiser of an additive image perturbation against a model.

    The wrapped ``model`` is switched to eval mode and its parameters are
    frozen, so only the perturbation tensor receives gradients.
    """

    def __init__(self, args, model, targets, device='cuda:0', is_rtp=False):
        """Store the configuration and freeze ``model``.

        Args:
            args: run configuration. ``targeted_attack_B2H`` reads
                ``batch_size``, ``th``, ``ours`` and ``save_dir`` from it.
            model: callable returning a dict with a ``'loss'`` entry; must also
                provide ``eval``, ``requires_grad_``, ``zero_grad`` and
                ``generate`` (all used in this class).
            targets: initial list of target strings.
            device: device the image tensors are moved to.
            is_rtp: stored flag; not read anywhere in this class.
        """
        self.args = args
        self.model = model
        self.device = device
        self.is_rtp = is_rtp

        self.targets = targets
        self.num_targets = len(targets)

        # Loss value of every optimisation step, in order.
        self.loss_buffer = []

        # Freeze the model and set it to eval mode: only the noise is optimised.
        self.model.eval()
        self.model.requires_grad_(False)

    @staticmethod
    def _load_first_column(path):
        """Return the first field of every non-empty row of the CSV file at *path*."""
        # ``with`` guarantees the handle is closed even if parsing raises;
        # newline="" is what the csv module requires for correct quoting.
        with open(path, "r", newline="", encoding="utf-8") as file:
            return [row[0] for row in csv.reader(file, delimiter=",") if row]

    def targeted_attack_B2H(self, img, batch_size = 1, num_iter=2000, alpha=1/255, epsilon = 128/255, ours = True, before = True):
        """Optimise an L-inf bounded perturbation of ``img`` by signed-gradient descent.

        Every step samples target strings from three CSV corpora under
        ``harmful_corpus/``, feeds the perturbed image to ``self.model`` and
        steps the noise against the sign of the loss gradient. Every 1000
        steps it prints a sample ``generate`` output and saves the current
        image (``.png``) and tensor (``.pt``) to ``self.args.save_dir``.

        Args:
            img: image tensor, presumably already normalised (it is passed
                through ``denormalize`` first) and shaped (B, 3, H, W);
                TODO confirm.
            batch_size: ignored — overwritten by ``self.args.batch_size``.
            num_iter: the loop runs ``num_iter + 1`` steps.
            alpha: step size of each signed-gradient update.
            epsilon: L-inf bound on the perturbation.
            ours: unused; the branch choice reads ``self.args.ours`` instead.
            before: unused.

        Returns:
            None. Results go to files, stdout and ``self.loss_buffer``.
        """
        # NOTE(review): the ``batch_size`` argument is always overridden here.
        batch_size = self.args.batch_size
        print('>>> batch_size:', batch_size)

        # NOTE(review): this initial noise is discarded and redrawn at the start
        # of the step loop below. It is kept so the torch RNG stream (and hence
        # seeded runs) stays unchanged.
        adv_noise = torch.rand_like(img).to(self.device) * 2 * epsilon - epsilon
        x = denormalize(img).clone().to(self.device)
        adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data

        adv_noise.requires_grad_(True)
        adv_noise.retain_grad()

        # Target pools. ``self.targets`` replaces the constructor argument.
        # NOTE(review): ``self.num_targets`` is not updated to match.
        self.targets = self._load_first_column("harmful_corpus/derogatory_corpus.csv")
        self.targets_1 = self._load_first_column("harmful_corpus/benign_sentences.csv")
        self.targets_2 = self._load_first_column("harmful_corpus/harmful_words.csv")

        question_adv = ''

        for step in range(1):
            class_n = step
            print('######### step = %d ##########' % step)
            # Fresh uniform noise in [-epsilon, epsilon), on x's device (self.device).
            adv_noise = torch.rand_like(x) * 2 * epsilon - epsilon
            adv_noise.requires_grad_(True)
            adv_noise.retain_grad()

            for t in tqdm(range(num_iter + 1)):
                # x == denormalize(img) on self.device is loop-invariant, so
                # reuse it rather than recomputing it every step.
                x_adv = normalize(x + adv_noise)

                # NOTE(review): this sample is overwritten in both branches
                # below. It is kept because it advances Python's RNG and raises
                # ValueError when batch_size > len(self.targets).
                batch_targets = random.sample(self.targets, batch_size)
                th = self.args.th
                many_shot_num = 1

                # With args.ours disabled, rand_prob = 2 exceeds any th < 2, so
                # the "label 1 = label 2" branch is always taken.
                if self.args.ours:
                    rand_prob = random.random()
                else:
                    rand_prob = 2

                if th < rand_prob:  # label 1 = label 2
                    batch_targets = random.sample(self.targets, many_shot_num)
                    batch_targets_2 = batch_targets
                else:  # label 1 != label 2
                    batch_targets = random.sample(self.targets_1, many_shot_num)
                    # Up to 50 words, repeated 20 * many_shot_num times and
                    # joined into one string.
                    batch_targets_2 = random.sample(self.targets_2, min(50, len(self.targets_2)))
                    batch_targets_2 = batch_targets_2 * 20 * many_shot_num
                    batch_targets_2 = [" ".join(batch_targets_2)]

                # NOTE(review): 'text_input' has batch_size entries while
                # 'text_output' has many_shot_num; confirm the model accepts this.
                samples = {
                    'image': x_adv,
                    'text_input': [''] * batch_size,
                    'text_output': batch_targets,
                    'text_output_2': batch_targets_2,
                }
                target_loss = self.model(samples)['loss']
                target_loss.backward()

                # Signed-gradient descent step. Project onto the epsilon ball,
                # then onto perturbations that keep x + noise inside [0, 1].
                adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(-epsilon, epsilon)
                adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data

                adv_noise.grad.zero_()
                self.model.zero_grad()

                self.loss_buffer.append(target_loss.item())

                if t % 1000 == 0:
                    print('######### Output - Iter = %d ##########' % t)
                    x_org = normalize(x)
                    x_adv = normalize(x + adv_noise)

                    with torch.no_grad():
                        print('>>> Sample Outputs')
                        response_adv = self.model.generate({"image": x_adv, "prompt": f'{question_adv}'},
                                                           use_nucleus_sampling=True, top_p=0.9, temperature=1)
                        print(response_adv)
                    adv_img_prompt = denormalize(x_adv).detach().cpu()
                    adv_img_prompt = adv_img_prompt.squeeze(0)
                    save_image(adv_img_prompt, '%s/class%d_iter%d.png' % (self.args.save_dir, class_n, t))
                    torch.save(x_adv, '%s/class%d_iter%d.pt' % (self.args.save_dir, class_n, t))

                    if t == 0:
                        # NOTE(review): these paths are identical to the ones
                        # just written, so at t == 0 the clean image overwrites
                        # the perturbed one. Confirm this is intended.
                        adv_img_prompt = denormalize(x_org).detach().cpu()
                        adv_img_prompt = adv_img_prompt.squeeze(0)
                        save_image(adv_img_prompt, '%s/class%d_iter%d.png' % (self.args.save_dir, class_n, t))
                        torch.save(x_org, '%s/class%d_iter%d.pt' % (self.args.save_dir, class_n, t))

        return None